/*
 * Copyright (c) 2021, LiWangQian<liwangqian@huawei.com> All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification,
 * are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this list of
 *    conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice, this list
 *    of conditions and the following disclaimer in the documentation and/or other materials
 *    provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its contributors may be used
 *    to endorse or promote products derived from this software without specific prior written
 *    permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
 * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
 * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
 * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
 * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
 * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
 * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
 * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */

#include <atomic>
#include <cstdint>
#include <thread>
#include <vector>

#include <catch2/catch.hpp>
#include "refcount.h"

/* Minimal payload stored inside a ref-counted block by the tests in this file. */
struct TestObject {
    uint32_t id; /* initialized to 1 by TestObjectCtor */
};

static inline
bool TestObjectCtor(void *memPtr, void *ctx)
{
    if (memPtr == nullptr) {
        return false;
    }

    auto obj = static_cast<TestObject *>(memPtr);
    obj->id = 1;
    return true;
}

/*
 * Cleanup function used by TestObjAutoPtr.
 * Clears the caller's pointer and then releases the reference it stood for.
 * A null `obj`, or a null `*obj`, makes this a no-op.
 */
static inline
void TestObjectDtor(TestObject **obj)
{
    TestObject *const target = (obj != nullptr) ? *obj : nullptr;
    if (target != nullptr) {
        RCHandle owner = RefCountGetHandle(target);
        *obj = nullptr;  // clear before unref so the caller never keeps a dangling pointer
        RefCountUnref(&owner);
    }
}

// Declares a TestObject pointer whose cleanup routine is TestObjectDtor, via the
// RAII macro. NOTE(review): RAII is not defined in this file — presumably it
// comes from refcount.h and runs the cleanup at scope exit; confirm there.
#define TestObjAutoPtr RAII(TestObject *, TestObjectDtor)

/*
 * Creates a ref-counted TestObject whose storage is initialized by TestObjectCtor.
 * Returns a pointer to the object's data, or nullptr if the ref-counted block
 * could not be created.
 */
static inline TestObject *TestObjectCreate()
{
    // Built exactly once on first use. Function-local static initialization is
    // thread-safe, whereas re-assigning `allocator.ctor` on every call was an
    // unsynchronized write to shared state. Kept static (rather than a local)
    // in case the ref-count block retains a pointer to its allocator — confirm
    // against refcount.h before changing.
    static Allocator allocator = [] {
        Allocator configured = *GetDefaultAllocator();
        configured.ctor = TestObjectCtor;
        return configured;
    }();

    RCHandle rcHandle = RefCountCreate(sizeof(TestObject), &allocator);
    if (!RefCountIsValid(rcHandle)) {
        return nullptr;  // don't ask an invalid handle for its data
    }
    return static_cast<TestObject *>(RefCountGetData(rcHandle));
}

TEST_CASE("Test RefCount Methods")
{
    // `ptr` releases its reference through TestObjectDtor when the test ends.
    TestObjAutoPtr ptr = TestObjectCreate();
    REQUIRE(ptr != nullptr);
    REQUIRE(ptr->id == 1);  // value written by TestObjectCtor
    RCHandle rcHandle = RefCountGetHandle(ptr);

    // A freshly created object starts with exactly one reference.
    REQUIRE(RefCountIsValid(rcHandle));
    REQUIRE(RefCountGetCount(rcHandle) == 1);

    // Taking a reference bumps the count by one...
    auto count = RefCountGetCount(rcHandle);
    auto newRcHandle = RefCountRef(rcHandle);
    REQUIRE(RefCountGetCount(newRcHandle) == (count + 1));

    // ...and dropping it invalidates only the released handle and restores the
    // original count, leaving the object alive for `ptr`.
    RefCountUnref(&newRcHandle);
    REQUIRE(!RefCountIsValid(newRcHandle));
    REQUIRE(RefCountIsValid(rcHandle));
    REQUIRE(RefCountGetCount(rcHandle) == count);
}

TEST_CASE("Test RefCount in Multi-Threading context")
{
    constexpr size_t kThreadCount = 20;

    Allocator allocator = *GetDefaultAllocator();
    allocator.ctor = TestObjectCtor;

    RCHandle rcHandle = RefCountCreate(sizeof(TestObject), &allocator);

    REQUIRE(RefCountIsValid(rcHandle));
    REQUIRE(RefCountGetCount(rcHandle) == 1);

    // Catch2 assertion macros are not thread-safe. A failing REQUIRE also throws,
    // which would call std::terminate inside a std::thread. Workers therefore only
    // record what they observe, and the assertions run on this thread after join.
    std::atomic<size_t> badObservations{0};

    std::vector<std::thread> threads;
    threads.reserve(kThreadCount);
    for (size_t i = 0; i < kThreadCount; i++) {
        // The extra reference is taken here, before the worker starts. Together
        // with the one `rcHandle` holds, each worker should see a count above 1.
        threads.emplace_back([&badObservations](RCHandle handle) {
            if (RefCountGetCount(handle) <= 1) {
                badObservations++;
            }
            RefCountUnref(&handle);
        }, RefCountRef(rcHandle));
    }

    for (auto &t : threads) {
        t.join();
    }

    REQUIRE(badObservations.load() == 0);
    REQUIRE(RefCountGetCount(rcHandle) == 1);
    RefCountUnref(&rcHandle);
    REQUIRE(!RefCountIsValid(rcHandle));
}
